# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train model."""
import os
import numpy as np
from mindspore import Model, load_checkpoint, load_param_into_net, Tensor
from mindspore import nn
from mindspore.nn import TrainOneStepCell
from mindspore.train.callback import LossMonitor
from mindspore.train.callback import SummaryCollector

from bert_with_pooler.bert_zh import Model as BertWithPooler
from downstream_task import EmotionClassifier, LossCell, LearningRate
from utils.dataset import load_dataset
from utils.lr_scheduler import polynomial_decay_scheduler
from config.vocabulary import Vocabulary


def pad(ids: list, attn: list, tgt_len: int = 512):
    """Pad token ids and attention mask with zeros up to ``tgt_len``.

    Unlike a naive in-place version, the caller's lists are NOT mutated;
    padded copies are built instead. Sequences already at or above
    ``tgt_len`` are returned unchanged in length (no truncation).

    Args:
        ids (list): Token id sequence.
        attn (list): Attention mask sequence (same length as ``ids``).
        tgt_len (int): Target sequence length. Default: 512.

    Returns:
        tuple: Two ``np.int64`` arrays of shape ``(1, len)`` — padded ids
        and padded attention mask.
    """
    # Clamp so an over-length input does not produce a negative pad count.
    pad_len = max(0, tgt_len - len(ids))
    padded_ids = list(ids) + [0] * pad_len
    padded_attn = list(attn) + [0] * pad_len
    return (np.array([padded_ids]).astype(np.int64),
            np.array([padded_attn]).astype(np.int64))


def predict(classifier, input_ids, attn_mask, raw_text, seq_len: int = 512):
    """Run inference on one tokenized sentence and print a small report.

    Args:
        classifier: Trained classifier cell; called with
            ``(input_ids, attn_mask, token_type_ids)`` tensors and expected
            to return class logits/probabilities.
        input_ids (list): Token ids of the sentence.
        attn_mask (list): Attention mask of the sentence.
        raw_text (str): Original sentence text, echoed in the report.
        seq_len (int): Model sequence length (was hard-coded to 512).
            Default: 512.
    """
    input_ids, attn_mask = pad(input_ids, attn_mask, seq_len)
    # Single-sentence input: every position belongs to segment 0.
    token_type_ids = np.zeros(seq_len).astype(np.int64).reshape(-1, seq_len)
    output = classifier(Tensor(input_ids), Tensor(attn_mask), Tensor(token_type_ids))
    # Index 1 is treated as the positive class by this task.
    idx = np.argmax(output.asnumpy())
    print("=" * 88)
    print(f"Raw Sentence: {raw_text}")
    print(f"Prob distribution: {output.asnumpy()}")
    print(f"It's {'positive' if idx == 1 else 'negative'}.")
    print("=" * 88)


def train(train_dataset, test_dataset, pretrained_model_ckpt_path, summary_dir,
          test_sentence=None, epoch: int = 10, batch_size: int = 32):
    """Fine-tune the migrated BERT model on an emotion-classification task.

    Loads the pretrained checkpoint, trains the downstream classifier,
    reports accuracy on the test split, and optionally classifies one
    extra sentence at the end.

    Args:
        train_dataset (str): Path/identifier of the training dataset,
            forwarded to ``load_dataset``.
        test_dataset (str): Path/identifier of the evaluation dataset.
        pretrained_model_ckpt_path (str): Checkpoint file of the migrated
            BERT model.
        summary_dir (str): Output directory for MindInsight summary data.
        test_sentence (str, optional): If given, tokenized with the bundled
            vocabulary and classified after training. Default: None.
        epoch (int): Number of training epochs. Default: 10.
        batch_size (int): Training batch size. Default: 32.

    Raises:
        RuntimeError: If the pretrained checkpoint does not fully load
            into the network.
    """
    # Load training dataset.
    ds_train = load_dataset(train_dataset, batch_size)

    # Load migrated bert model.
    pretrained_model = BertWithPooler()
    param_dict = load_checkpoint(pretrained_model_ckpt_path)
    not_load_params = load_param_into_net(pretrained_model, param_dict)
    # `assert` is stripped under `python -O`; raise explicitly so a
    # partially-loaded checkpoint can never train silently.
    if not_load_params:
        raise RuntimeError(f"Params is not fully loaded: {not_load_params}")

    # Define downstream task.
    classifier = EmotionClassifier(pretrained_model)
    classifier.set_train(True)

    # Define loss function here.
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    model_with_loss = LossCell(classifier, loss_fn)

    # Define learning rate scheduler and optimizer here.
    learning_rate = polynomial_decay_scheduler(lr=1e-2, min_lr=1e-4, decay_steps=200,
                                               total_update_num=ds_train.get_dataset_size() * epoch,
                                               warmup_steps=100, power=1.2)
    optimizer = nn.Momentum(params=classifier.trainable_params(),
                            learning_rate=LearningRate(learning_rate),
                            momentum=0.9)
    train_model = TrainOneStepCell(model_with_loss, optimizer)

    model = Model(train_model)
    # Collect summary data every 10 steps for visualization in MindInsight.
    summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=10)
    model.train(epoch, ds_train, callbacks=[summary_collector, LossMonitor(10)], dataset_sink_mode=False)

    # Evaluate: count exact label matches over the whole test split.
    eval_dataset = load_dataset(test_dataset)
    classifier.set_train(False)
    acc, total = 0, 0
    for sample in eval_dataset:
        input_ids, attn_mask, token_type_ids, label = sample[0], sample[1], sample[2], sample[3]
        logits = classifier(input_ids, attn_mask, token_type_ids)
        pred = np.argmax(logits.asnumpy(), axis=1).reshape(-1)
        matched = np.where(pred == label.asnumpy().reshape(-1))[0]
        acc += len(matched)
        total += input_ids.shape[0]
    print(f"Accuracy: {acc / total * 100:.2f}%")

    # Optional single-sentence sanity check using the bundled vocabulary.
    if test_sentence is not None:
        vocab_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config/vocab.txt")
        vocabs = Vocabulary.load_from_text(vocab_path)
        input_ids, attn_mask = vocabs.tokenize(test_sentence)
        predict(classifier, input_ids, attn_mask, test_sentence)
